from mynumpy import *

import seaborn as sns
from estimator_naive      import estimator_naive
from estimator_antithetic import estimator_antithetic
from estimator_iid        import estimator_iid
from estimator_stratified import estimator_stratified
# from estimator_shotgun    import estimator_shotgun
# from estimator_chord      import estimator_chord
# from estimator_tied       import estimator_tied
# from estimator_copula     import estimator_copula
# from estimator_qmc        import estimator_qmc
#from targets import targets
import targets
from matplotlib import pyplot as plt
import ezopt
from dists import mydist
import uniformdist
import util

import warnings
# Silence every FutureWarning for the rest of this script so the console output stays readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

from IPython import embed

# Fine evaluation grid on [-10, 10).
# NOTE(review): z is reassigned to a coarser grid further down before anything reads it,
# so this assignment appears to be dead — confirm and remove.
z = arange(-10,10,.01)

def tostring(a):
    """Render a number as fixed-point text with exactly three decimals (e.g. 1.23456 -> '1.235')."""
    return f'{a:.3f}'

def elbo(est,w,omega=None):
    """Monte Carlo estimate of E[log R] for estimator `est` at parameters `w`.

    Args:
        est: estimator object exposing `sample_omega(n)` and `logR(omega, w)`.
        w: variational parameters passed through to `est.logR`.
        omega: pre-drawn base samples. When None, 10000*expense fresh samples are
            drawn from `est.sample_omega` (reads the module-level `expense`).

    Returns:
        The sample mean of `est.logR(omega, w)`, or -1e10 when that mean is NaN or
        infinite, so a maximizer is pushed away from degenerate parameters.
    """
    samples = est.sample_omega(10000*expense) if omega is None else omega
    value = mean(est.logR(samples, w))
    return value if np.isfinite(value) else -1e10

def opt_elbo(est):
    """Maximize the Monte Carlo E[log R] of `est` over its parameters with L-BFGS-B.

    One batch of 10000*expense omegas is drawn up front and reused for every
    objective evaluation, so the optimizer sees a fixed, deterministic objective.
    The starting point is small Gaussian noise of length `est.w_dim`.

    Returns:
        The optimized parameter vector returned by `ezopt.opt1`.
    """
    n_opt = 10000*expense
    fixed_omega = est.sample_omega(n_opt)
    w_init = 1e-2*randn(est.w_dim)

    def objective(w):
        return elbo(est, w, fixed_omega)

    return ezopt.opt1(objective, w_init, max=True, method='L-BFGS-B')

botbuf = .01


def init_dist_fig():
    """Open a new wide, short figure for a 1-D density on [-10, 10].

    The axes are hidden. The y-range starts slightly below zero (by botbuf + .01) and
    tops out at .45. Margins are trimmed so the plotted curves fill the figure.
    """
    plt.figure(figsize=(5, 1.25))
    sns.set_style('white')
    plt.axis([-10, 10, -botbuf - .01, .45])
    plt.axis('off')
    plt.subplots_adjust(bottom=0, top=.99, left=0.02, right=0.98)

def label_fig(text1,text2=None):
    """Annotate the current axes with one or two text lines at data x=-7 (y=.4, then y=.35)."""
    plt.text(-7, .4, text1)
    if text2 is None:
        return
    plt.text(-7, .35, text2)

def make_leg(y_adjust=0):
    """Draw the legend anchored at the left edge; `y_adjust` raises the anchor box above .7."""
    anchor = [0, .7 + y_adjust, .1, .1]
    plt.legend(loc='upper left', bbox_to_anchor=anchor)

# Fix the RNG seed (presumably numpy's global RNG via mynumpy — TODO confirm) so the
# omega draws and random initial w used below are reproducible. Every figure depends
# on the exact order of random draws after this point.
seed(1)

# NOTE(review): pairs an autograd normal with scipy's normal — presumably one for
# differentiable log-densities and one for sampling/evaluation; confirm in dists.mydist.
mynorm = mydist(autograd.scipy.stats.norm,scipy.stats.norm)

# Scale factor for all sample counts. elbo()/opt_elbo() read this global at call time
# (10000*expense omegas), and kde_samps below uses it as well.
expense = 100*5
bw      = .025 # KDE bandwidth, passed as bw_method to scipy.stats.gaussian_kde
lw      = 1.0+.5 # linewidth
alpha   = .6
ms      = 3 # markersize
color_base = 'gray'       # colour of the target p(z,x) curve
line_sampled = '--'       # line style for KDE-estimated (sampled) densities
#color_dist = 'darkred'
color_samp = 'black'      # colour of the illustrative sample markers
marker_samp = '.'
kde_samps  = 10000*expense  # number of z draws fed to each KDE
#color_dist = sns.color_palette("viridis", n_colors=4)
#color_dist = sns.hls_palette(4, l=.3, s=.8)
#color_dist = sns.color_palette()
# Method colours as used below: [0] naive, [1] antithetic, [2] antithetic within strata, [3] stratified.
color_dist = sns.color_palette("dark", 4)

# NOTE(review): presumably a bimodal/imbalanced 1-D target exposing .p(z); see targets module.
t = targets.target_imbalanced()

# Evaluation grid on [-10, 10] with step .05; the +1e-9 makes the endpoint 10 included
# (arange excludes its stop value).
z = arange(-10,10+1e-9,.01*5)

# first, picture the target itself
init_dist_fig()
plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
#label_fig('Target')
make_leg()
plt.savefig('target.pdf',transparent=True)
plt.close()

# Good-old-fashioned (naive) VI: optimize w for the naive estimator, then plot the
# fitted q(z) against the target with one illustrative sample marked on the curve.
init_dist_fig()
est_naive = estimator_naive(t,mynorm)
w = opt_elbo(est_naive)
# a single illustrative draw at a fixed omega
omega = array([.15])
zs    = est_naive.sample_z(omega,w)
elbo_naive = elbo(est_naive,w)
plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
plt.plot(z,est_naive.a(z,w),color=color_dist[0],label='$q(z)$, naive',alpha=alpha,linewidth=lw)
plt.plot(zs,est_naive.a(zs,w),marker_samp,color=color_samp,markersize=ms)
make_leg()
# Raw string: '\m' and '\l' are invalid escapes in a plain literal (SyntaxWarning on 3.12+).
plt.text(-9.5,.35,r'$\mathbb{E} \log R = ' + tostring(elbo_naive) + '$')
plt.savefig('vi.pdf',transparent=True)
plt.close()
# save the density for later
p_naive = est_naive.a(z,w)

# Antithetic VI: optimize w for the antithetic estimator built on the naive one.
init_dist_fig()
est_anti = estimator_antithetic(est_naive)
w = opt_elbo(est_anti)
p0_anti = est_naive.a(z,w)
elbo_anti = elbo(est_anti,w)
# the plain (naive) E log R at the antithetically-optimized w, for comparison
my_elbo0 = elbo(est_naive,w)
print('antithetic elbo0',my_elbo0)
# an illustrative antithetic pair: .15 and .85 mirror each other about omega = .5
omega = array([.15,.85])
zs    = est_naive.sample_z(omega,w)
plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
plt.plot(z,est_naive.a(z,w),color=color_dist[1],label='$q(z)$, antithetic',alpha=alpha,linewidth=lw)
plt.plot(zs,est_naive.a(zs,w),marker_samp,color=color_samp,markersize=ms)
# Raw string avoids the invalid '\m'/'\l' escapes; double quotes let the prime stay unescaped.
plt.text(-9.5,.35,r"$\mathbb{E} \log R' = " + tostring(elbo_anti) + '$')
# thin dotted line at the mirror point omega = .5
zs = est_naive.sample_z(.5,w)
p = est_naive.a(zs,w)
plt.plot([zs,zs],[0,p],':',color=color_dist[1],linewidth=0.5)
make_leg()
plt.savefig('anti.pdf',transparent=True)
plt.close()
# save the sampled density for later: KDE of many draws from the antithetic estimator
omega = est_anti.sample_omega(kde_samps)
zs = est_anti.sample_z(omega,w)
p_anti = scipy.stats.gaussian_kde(zs,bw)(z)

# Stratified VI with 4 strata built on the naive estimator.
init_dist_fig()
est_strat = estimator_stratified(est_naive,4)
w = opt_elbo(est_strat)
p0_strat = est_naive.a(z,w)
elbo_strat = elbo(est_strat,w)
# illustrative omegas, one in each quarter of [0, 1]
omega = array([.15,.36,.7,.97])
zs    = est_naive.sample_z(omega,w)
plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
plt.plot(z,est_naive.a(z,w),'-',color=color_dist[3],label='$q(z)$, stratified',alpha=alpha,linewidth=lw)
plt.plot(zs,est_naive.a(zs,w),marker_samp,color=color_samp,markersize=ms)
# thin lines at omega = .25, .5, .75 (stratum edges, assuming 4 equal-width strata)
for omega in [.25,.5,.75]:
    zs = est_naive.sample_z(omega,w)
    p = est_naive.a(zs,w)
    plt.plot([zs,zs],[0,p],color=color_dist[3],linewidth=0.5)
make_leg()
# Raw string avoids the invalid '\m'/'\l' escapes.
plt.text(-9.5,.35,r"$\mathbb{E} \log R' = " + tostring(elbo_strat) + '$')
plt.savefig('strat.pdf',transparent=True)
plt.close()
# KDE of many draws from the stratified estimator. The omegas must come from
# est_strat itself; the previous code drew them from est_anti.
omega = est_strat.sample_omega(kde_samps)
zs = est_strat.sample_z(omega,w)
p_strat = scipy.stats.gaussian_kde(zs,bw)(z)

# Antithetic sampling inside each stratum: wrap the 4-stratum estimator.
init_dist_fig()
est_antistrat = estimator_antithetic(est_strat)
w = opt_elbo(est_antistrat)
p0_anti_strat = est_naive.a(z,w)
elbo_anti_strat = elbo(est_antistrat,w)
# illustrative omegas: one per stratum, plus its reflection about that stratum's center
omega0 = array([.15,.36,.7,.97])
centers = array([0,.25,.5,.75])+.25/2
omega1  = centers-(omega0-centers)
omega   = np.concatenate([omega0,omega1])
zs = est_naive.sample_z(omega,w)
plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
plt.plot(z,est_naive.a(z,w),label='$q(z)$, antithetic within strata',color=color_dist[2],alpha=alpha,linewidth=lw)
plt.plot(zs,est_naive.a(zs,w),marker_samp,color=color_samp,markersize=ms)
# solid lines at the stratum edges, dotted lines at the stratum centers (edge + .125)
for omega in [0,.25,.5,.75]:
    zs = est_naive.sample_z(omega,w)
    p = est_naive.a(zs,w)
    plt.plot([zs,zs],[0,p],color=color_dist[2],linewidth=0.5)
    zs = est_naive.sample_z(omega+.125,w)
    p = est_naive.a(zs,w)
    plt.plot([zs,zs],[0,p],':',color=color_dist[2],linewidth=0.5)
make_leg()
# Raw string avoids the invalid '\m'/'\l' escapes.
plt.text(-9.5,.35,r"$\mathbb{E} \log R' = " + tostring(elbo_anti_strat) + '$')
plt.savefig('anti_strat.pdf',transparent=True)
plt.close()
# KDE of many draws from the antithetic-within-strata estimator
omega = est_antistrat.sample_omega(kde_samps)
zs = est_antistrat.sample_z(omega,w)
p_antistrat = scipy.stats.gaussian_kde(zs,bw)(z)

# Plot each Monte Carlo (KDE) density on its own against the target, annotated with
# the E log R' of the estimator it was sampled from.
for p_mc, col, method_name, elbo_val, fname in [
        (p_anti,      color_dist[1], 'antithetic',               elbo_anti,       'sampled_anti.pdf'),
        (p_strat,     color_dist[3], 'stratified',               elbo_strat,      'sampled_strat.pdf'),
        (p_antistrat, color_dist[2], 'antithetic within strata', elbo_anti_strat, 'sampled_anti_strat.pdf'),
]:
    init_dist_fig()
    plt.plot(z,t.p(z),color=color_base,label='$p(z,x)$',linewidth=lw)
    # raw strings: '\m' and '\l' are invalid escapes in plain literals
    plt.plot(z,p_mc,line_sampled,color=col,alpha=alpha,linewidth=lw,
             label=r'$q_\mathrm{MC}(z|x)$, ' + method_name)
    plt.text(-9.5,.35,r"$\mathbb{E} \log R' = " + tostring(elbo_val) + '$')
    make_leg()
    plt.savefig(fname,transparent=True)
    plt.close()

# Compare the fitted q(z) of every method on one figure (target drawn unlabeled).
init_dist_fig()
plt.plot(z,t.p(z)      ,color=color_base,linewidth=lw)
plt.plot(z,p_naive      ,color=color_dist[0],alpha=alpha,linewidth=lw,label='$q(z)$, naive')
plt.plot(z,p0_anti      ,color=color_dist[1],alpha=alpha,linewidth=lw,label='$q(z)$, antithetic')
plt.plot(z,p0_strat     ,color=color_dist[3],alpha=alpha,linewidth=lw,label='$q(z)$, stratified')
plt.plot(z,p0_anti_strat,color=color_dist[2],alpha=alpha,linewidth=lw,label='$q(z)$, antithetic within strata')
make_leg(.25)
plt.savefig('all.pdf',transparent=True)
plt.close()

# Compare the naive q(z) with every sampled (KDE) density on one figure.
init_dist_fig()
plt.plot(z,t.p(z)     ,color=color_base,linewidth=lw)
plt.plot(z,p_naive                  ,color=color_dist[0],alpha=alpha,linewidth=lw,label='$q(z)$, naive')
plt.plot(z,p_anti      ,line_sampled,color=color_dist[1],alpha=alpha,linewidth=lw,label=r'$q_\mathrm{MC}(z|x)$, antithetic')
plt.plot(z,p_strat     ,line_sampled,color=color_dist[3],alpha=alpha,linewidth=lw,label=r'$q_\mathrm{MC}(z|x)$, stratified')
plt.plot(z,p_antistrat,line_sampled,color=color_dist[2],alpha=alpha,linewidth=lw,label=r'$q_\mathrm{MC}(z|x)$, antithetic within strata')
make_leg(.25)
plt.savefig('sampled_all.pdf',transparent=True)
plt.close()


# gaps_iid       = []
# gaps_strat     = []
# gaps_antistrat = []

# binss = arange(2,20,2)
# for bins in binss:
#     est_iid = estimator_iid(est_naive,bins)
#     w = opt_elbo(est_iid)
#     gaps_iid.append(-elbo(est_iid,w))

#     est_strat = estimator_stratified(est_naive,bins)
#     w = opt_elbo(est_strat)
#     gaps_strat.append(-elbo(est_strat,w))

#     est_strat = estimator_stratified(est_naive,bins//2)
#     est_anti  = estimator_antithetic(est_strat)
#     w = opt_elbo(est_anti)
#     gaps_antistrat.append(-elbo(est_anti,w))

    #print('bins',bins,'gap1',gap1,'gap2',gap2)

# plt.plot(binss,gaps_iid      ,label='iid')
# plt.plot(binss,gaps_strat    ,label='strat')
# plt.plot(binss,gaps_antistrat,label='anti-strat')
# plt.xlim([1,20])
# plt.ylim([0,.1])
# plt.legend()
# plt.xlabel('# samples')
# plt.ylabel('elbo gap')
# plt.savefig('elbo_gap.pdf')
# emb()